TensorFlow/Keras
Keras is a high-level neural networks API, written in Python and capable of running on top of TensorFlow, CNTK, or Theano. It was developed with a focus on enabling fast experimentation. Being able to go from idea to result with the least possible delay is key to doing good research.
Note 1: This is not an introduction to deep neural networks, as that would explode the scope of this notebook. Instead, we want to show you how to implement a convolutional neural network to classify neuroimages, in our case fMRI images.
Note 2: We want to thank Anisha Keshavan, as a lot of the content in this notebook comes from her introduction notebook about Keras.
import warnings
warnings.filterwarnings("ignore")
from nilearn import plotting
%matplotlib inline
import numpy as np
import nibabel as nb
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style="white")
import os
import datetime
import tensorflow as tf
import plotly.graph_objects as go
from plotly import figure_factory as ff
We will load a dataset that was prepared to enable quick showcases/introductions for machine learning. It includes an anatomical template image (we will need this for visualization), as well as a 4D fMRI image from a resting-state scan. The dataset is from Zang et al. It contains 48 subjects, each of whom completed two resting-state fMRI recordings: once with eyes open and once with eyes closed. The data was already pre-processed and is ready for the machine learning notebooks. Note: The data diverges from the original dataset in that we only consider the first 100 volumes per run for this tutorial; the original dataset had 240 volumes per run.

anat = nb.load('data/MNI152_T1_1mm.nii.gz')
func = nb.load('data/dataset_ML.nii.gz')
Let's see what the 4D fMRI image looks like by plotting its mean over time.
from nilearn.image import mean_img
from nilearn.plotting import plot_anat
plot_anat(mean_img(func), cmap='magma', colorbar=False, display_mode='x', vmax=2, annotate=False,
cut_coords=range(0, 49, 12), title='Mean value of machine learning dataset');
As in every other machine or deep learning application, we need chunks and labels variables to train the neural network. The labels tell the model what we want to classify, and the chunks are an easy way to make sure that the training and test datasets are split in an equal/balanced way.
So, as before, we specify which volumes of the dataset were recorded during the eyes-closed condition and which during the eyes-open condition.
From the dataset release we know that we have a total of 384 volumes in our dataset_ML.nii.gz file and that it's always 4 volumes of the eyes-closed condition, followed by 4 volumes of the eyes-open condition, and so on. Therefore our labels should be as follows:
labels = np.ravel([[['closed'] * 4, ['open'] * 4] for i in range(48)])
labels[:20]
array(['closed', 'closed', 'closed', 'closed', 'open', 'open', 'open',
'open', 'closed', 'closed', 'closed', 'closed', 'open', 'open',
'open', 'open', 'closed', 'closed', 'closed', 'closed'],
dtype='<U6')
Second, the chunks variable should not switch within a subject. So, as before, we can specify 6 chunks of 64 volumes (8 subjects) each:
chunks = np.ravel([[i] * 64 for i in range(6)])
chunks[:150]
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
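As a quick sanity check (our addition, not part of the original notebook), we can verify that the labels and chunks vectors match the dataset description: 384 volumes in total, evenly split between the two conditions, and 64 volumes per chunk.

```python
import numpy as np

# Rebuild the label and chunk vectors exactly as above
labels = np.ravel([[['closed'] * 4, ['open'] * 4] for i in range(48)])
chunks = np.ravel([[i] * 64 for i in range(6)])

# 384 volumes in total, evenly split between the two conditions
assert labels.shape == (384,) and chunks.shape == (384,)
assert (labels == 'closed').sum() == (labels == 'open').sum() == 192

# Each of the 6 chunks holds 64 volumes (8 subjects x 8 volumes)
assert all((chunks == i).sum() == 64 for i in range(6))
```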
Convolutional neural networks are very powerful (as you will see), but the computational power needed to train them can be incredibly demanding. For this reason, it's sometimes recommended to reduce the input space if possible.
In our case, we could try to train the neural network on only one very thin slab (a few slices) of the brain. So, instead of taking the data matrix of the whole brain, we just take a few slices in the region that we think is most likely to be predictive for the question at hand.
We know (or suspect) that the regions with the most predictive power are probably somewhere around the eyes and in the visual cortex. So let's try to specify a few slices that cover those regions.
So, let's try to just take a few slices around the eyes:
plot_anat(mean_img(func).slicer[...,5:-25], cmap='magma', colorbar=False,
display_mode='x', vmax=2, annotate=False, cut_coords=range(0, 49, 12),
title='Slab of the machine learning mean image');
Hmm... That doesn't seem to work. We want to cover the eyes and the visual cortex, but like this we're too far down, at the back of the head (in the cerebellum). One solution is to rotate the volume.
So let's do that:
# Rotation parameters
phi = 0.35
cos = np.cos(phi)
sin = np.sin(phi)
# Compute rotation matrix around x-axis
rotation_affine = np.array([[1, 0, 0, 0],
                            [0, cos, -sin, 0],
                            [0, sin, cos, 0],
                            [0, 0, 0, 1]])
new_affine = rotation_affine.dot(func.affine)
# Rotate and resample image to new orientation
from nilearn.image import resample_img
new_img = nb.Nifti1Image(func.get_fdata(), new_affine)
img_rot = resample_img(new_img, func.affine, interpolation='continuous')
del func
del new_img
# Delete zero-only rows and columns
from nilearn.image import crop_img
img_crop = crop_img(img_rot)
del img_rot
Let's check if the rotation worked.
plot_anat(mean_img(img_crop), cmap='magma', colorbar=False, display_mode='x', vmax=2, annotate=False,
cut_coords=range(-20, 30, 12), title='Rotated machine learning dataset');
Perfect! And which slab should we take? Let's try the slices 12, 13 and 14.
from nilearn.plotting import plot_stat_map
img_slab = img_crop.slicer[..., 12:15, :]
plot_stat_map(mean_img(img_slab), cmap='magma', bg_img=mean_img(img_crop), colorbar=False,
display_mode='x', vmax=2, annotate=False, cut_coords=range(-20, 30, 12),
title='Slab of rotated machine learning dataset');
Perfect, the slab seems to contain exactly what we want. Now that the data is ready we can continue with the actual machine learning part.
First things first, we need to define a training and testing set. This is really important because we need to make sure that our model can generalize to new, unseen data. Here, we randomly shuffle our data, and reserve 80% of it for our training data, and the remaining 20% for testing.
So let's first get the data into the right structure for Keras. For this, we need to swap some of the dimensions of our data matrix.
data = np.rollaxis(img_slab.get_fdata(), 3, 0)
data.shape
(384, 40, 56, 3)
As you can see, the goal is to have the different volumes along the first dimension, followed by the volume itself. Keep in mind that the last dimension (here of size 3) is treated as channels by the Keras model that we will be using below.
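To see what np.rollaxis does here, consider a toy array with the same layout (the values are zeros; only the shapes matter). np.moveaxis is the more modern equivalent:

```python
import numpy as np

# A stand-in for the slab data: x, y, slices (channels), time
fake = np.zeros((40, 56, 3, 384))

# Move axis 3 (time/volumes) to the front
data = np.rollaxis(fake, 3, 0)
print(data.shape)  # (384, 40, 56, 3)

# Same result with the recommended np.moveaxis
assert np.array_equal(data, np.moveaxis(fake, 3, 0))
```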
Note: To make this notebook reproducible, i.e. always leading to the "same" results, let's set a seed for the random split of the dataset. This should only be done for teaching purposes, not for real research, as randomness and chance are a crucial part of machine learning.
from numpy.random import seed
seed(0)
As a next step, let's create an index list that we can use to split the data and labels into training and test sets:
# Create list of indices and shuffle them
N = data.shape[0]
indices = np.arange(N)
np.random.shuffle(indices)
# Cut the dataset at 80% to create the training and test set
N_80p = int(0.8 * N)
indices_train = indices[:N_80p]
indices_test = indices[N_80p:]
# Split the data into training and test sets
X_train = data[indices_train, ...]
X_test = data[indices_test, ...]
print(X_train.shape, X_test.shape)
(307, 40, 56, 3) (77, 40, 56, 3)
We need to define a variable that holds the outcome (1 or 0) indicating whether the resting-state images were recorded with eyes open or closed. Luckily, we already have this information stored in the labels variable above. So let's split these labels into training and test sets:
y_train = labels[indices_train] == 'open'
y_test = labels[indices_test] == 'open'
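Because the split is random, it is worth checking that both sets stay roughly balanced between the two conditions. The snippet below (our addition) re-creates the split in a self-contained way, using the same seed of 0 as above, and prints the fraction of eyes-open volumes in each set:

```python
import numpy as np

labels = np.ravel([[['closed'] * 4, ['open'] * 4] for i in range(48)])

np.random.seed(0)
indices = np.arange(len(labels))
np.random.shuffle(indices)

# Cut at 80%, as in the notebook
n_80p = int(0.8 * len(labels))
y_train = labels[indices[:n_80p]] == 'open'
y_test = labels[indices[n_80p:]] == 'open'

# Both fractions should be reasonably close to 0.5
print(round(y_train.mean(), 2), round(y_test.mean(), 2))
```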
Before training, let's briefly explore the data with PCA and t-SNE to get a feeling for whether the two conditions are separable at all:
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
scaler = StandardScaler()
pca = PCA()
tsne = TSNE()
# Standardize each voxel, then project the flattened volumes with PCA
X_scaled = scaler.fit_transform(X_train.reshape(len(X_train), -1))
X_pca = pca.fit_transform(X_scaled)
# Cumulative explained variance per number of components
plt.plot(pca.explained_variance_ratio_.cumsum())
y_train
array([ True, True, True, True, False, False, True, False, True,
True, False, False, False, False, True, True, True, False,
False, False, True, False, False, False, False, True, True,
True, False, True, True, False, True, True, False, False,
True, True, True, True, False, False, False, False, False,
True, False, True, True, True, True, True, True, False,
True, True, False, True, False, False, False, True, True,
False, False, False, True, False, False, False, True, True,
True, True, False, True, True, False, False, False, False,
False, True, True, True, True, False, True, False, True,
True, False, False, True, True, True, False, False, True,
True, True, False, True, True, True, True, True, False,
False, True, False, True, True, False, True, True, True,
False, False, True, True, False, True, False, False, False,
True, False, True, True, True, False, False, True, False,
False, True, True, True, False, True, True, False, True,
True, True, False, False, False, True, False, True, False,
True, False, True, True, True, True, True, False, True,
True, False, False, True, True, False, True, True, False,
False, False, False, True, True, True, False, False, False,
False, True, True, True, False, False, True, True, True,
True, True, False, True, True, True, True, True, True,
False, False, True, False, False, True, True, False, False,
True, False, False, False, False, True, True, True, True,
True, False, False, True, True, True, False, False, False,
False, True, False, True, True, False, False, True, False,
True, True, True, False, False, False, False, True, True,
False, True, False, False, True, False, False, True, False,
False, True, False, True, True, False, False, False, False,
True, True, False, False, False, True, True, False, True,
True, False, False, True, False, False, True, True, True,
False, True, True, True, False, False, True, False, False,
False, True, False, False, False, True, False, True, False,
False, True, False, True, True, False, True, True, True,
False])
# Scatter of the first two principal components, colored by condition
plt.scatter(X_pca[:, 0], X_pca[:, 1], c=y_train, cmap='bwr')
# Project further down to 2D with t-SNE and plot again
X_tsne = tsne.fit_transform(X_pca)
plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=y_train, cmap='bwr')
# Voxel-wise mean and standard deviation across the training volumes
mean = X_train.mean(axis=0)
mean.shape
(40, 56, 3)
std = X_train.std(axis=0)
std.shape
(40, 56, 3)
# Distribution of voxel-wise standard deviations; the red line marks
# the threshold below which voxels are treated as (near-)constant
plt.hist(np.ravel(std), bins=100);
plt.vlines(0.05, 0, 1000, colors='red')
std[std<0.05] = 0  # zero out near-constant voxels so they can be masked
# Same idea for the voxel-wise means; the red line marks a candidate cut-off
plt.hist(np.ravel(mean), bins=100);
plt.vlines(0.25, 0, 1000, colors='red')
mean[mean<0.05] = 0  # note: the plotted cut-off is 0.25, but the code thresholds at 0.05
mask = (mean*std)!=0  # voxels that survive both thresholds
# z-score both sets using the training-set statistics
X_zscore_tr = (X_train-mean)/std
X_zscore_te = (X_test-mean)/std
X_zscore_tr.shape
(307, 40, 56, 3)
# Division by a zero std produces NaN/inf values; set those voxels to zero
X_zscore_tr[np.isnan(X_zscore_tr)]=0
X_zscore_te[np.isnan(X_zscore_te)]=0
X_zscore_tr[np.isinf(X_zscore_tr)]=0
X_zscore_te[np.isinf(X_zscore_te)]=0
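To convince ourselves that this standardization behaves as intended, here is a small self-contained sketch with synthetic data (the exact fMRI values don't matter): within voxels that survive the variance threshold, the training mean should be ~0 and the standard deviation ~1 after z-scoring.

```python
import numpy as np

rng = np.random.default_rng(0)
fake_train = rng.normal(loc=2.0, scale=3.0, size=(300, 10, 12, 3))

# Voxel-wise statistics across volumes, as in the notebook
mean = fake_train.mean(axis=0)
std = fake_train.std(axis=0)
std[std < 0.05] = 0               # zero out near-constant voxels

z = (fake_train - mean) / std     # NaN/inf wherever std == 0
z[~np.isfinite(z)] = 0

mask = std != 0
assert np.allclose(z[:, mask].mean(axis=0), 0, atol=1e-6)
assert np.allclose(z[:, mask].std(axis=0), 1, atol=1e-6)
```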
And now we're good to go.
Now comes the fun and tricky part: we need to specify the structure of our convolutional neural network. As a quick reminder, a convolutional neural network consists of convolution layers, pooling layers, a flattening layer, and fully connected layers:

So as a first step, let's import all modules that we need to create the keras model:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AvgPool2D, BatchNormalization
from tensorflow.keras.layers import Activation, Dropout, Flatten, Dense
from tensorflow.keras.optimizers import Adam, SGD
As a next step, we should specify some of the model parameters that we want to be identical throughout the model:
# Get shape of input data
data_shape = tuple(X_train.shape[1:])
# Specify shape of convolution kernel
kernel_size = (3, 3)
# Specify number of output categories
n_classes = 2
Now comes the big part... the model, i.e. the structure of the neural network! We want to make clear that we're no experts in deep neural networks, so the model below is not necessarily a good one. But we chose it because it can be estimated rather quickly and has comparatively few parameters.
# Specify number of filters per layer
filters = 32
model = Sequential()
model.add(Conv2D(filters, kernel_size, activation='relu', input_shape=data_shape))
model.add(BatchNormalization())
model.add(MaxPooling2D())
filters *= 2
model.add(Conv2D(filters, kernel_size, activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D())
filters *= 2
model.add(Conv2D(filters, kernel_size, activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D())
filters *= 2
model.add(Flatten())
model.add(Dropout(0.5))
model.add(Dense(1024, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(256, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(64, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(n_classes, activation='softmax'))
model.compile(loss='sparse_categorical_crossentropy',
optimizer='adam', # swap out for sgd
metrics=['accuracy'])
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                                Output Shape        Param #
=================================================================
 conv2d (Conv2D)                             (None, 38, 54, 32)  896
 batch_normalization (BatchNormalization)    (None, 38, 54, 32)  128
 max_pooling2d (MaxPooling2D)                (None, 19, 27, 32)  0
 conv2d_1 (Conv2D)                           (None, 17, 25, 64)  18496
 batch_normalization_1 (BatchNormalization)  (None, 17, 25, 64)  256
 max_pooling2d_1 (MaxPooling2D)              (None, 8, 12, 64)   0
 conv2d_2 (Conv2D)                           (None, 6, 10, 128)  73856
 batch_normalization_2 (BatchNormalization)  (None, 6, 10, 128)  512
 max_pooling2d_2 (MaxPooling2D)              (None, 3, 5, 128)   0
 flatten (Flatten)                           (None, 1920)        0
 dropout (Dropout)                           (None, 1920)        0
 dense (Dense)                               (None, 1024)        1967104
 batch_normalization_3 (BatchNormalization)  (None, 1024)        4096
 dropout_1 (Dropout)                         (None, 1024)        0
 dense_1 (Dense)                             (None, 256)         262400
 batch_normalization_4 (BatchNormalization)  (None, 256)         1024
 dropout_2 (Dropout)                         (None, 256)         0
 dense_2 (Dense)                             (None, 64)          16448
 batch_normalization_5 (BatchNormalization)  (None, 64)          256
 dropout_3 (Dropout)                         (None, 64)          0
 dense_3 (Dense)                             (None, 2)           130
=================================================================
Total params: 2,345,602
Trainable params: 2,342,466
Non-trainable params: 3,136
_________________________________________________________________
That's what our model looks like! Cool!
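The parameter counts in the summary can be verified by hand (our arithmetic, following the standard formulas): a Conv2D layer has kernel_height * kernel_width * input_channels * filters weights plus one bias per filter, a Dense layer has inputs * units weights plus one bias per unit, and BatchNormalization carries 4 parameters per channel (2 trainable, 2 not):

```python
# First Conv2D: 3x3 kernel, 3 input channels, 32 filters (+ biases)
conv1 = (3 * 3 * 3) * 32 + 32        # 896
# Second Conv2D: 3x3 kernel, 32 input channels, 64 filters
conv2 = (3 * 3 * 32) * 64 + 64       # 18496
# First BatchNormalization: 4 parameters per channel
bn1 = 4 * 32                         # 128
# First Dense layer: 1920 flattened inputs -> 1024 units
dense1 = 1920 * 1024 + 1024          # 1967104

print(conv1, conv2, bn1, dense1)
```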
The next step is now, of course, to fit our model to the training data. In our case we have two parameters that we can work with:
First: How many epochs, i.e. complete passes over the training data, should be computed?
nEpochs = 100 # Increase this value for better results (i.e., more training)
Second: How many elements (volumes) should be considered at once for the updating of the weights?
batch_size = 32 # Increasing this value might speed up fitting
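From these two settings we can estimate how many weight updates happen per epoch. Keras keeps the first (1 - validation_split) fraction of the training array for fitting and holds out the rest for validation, so with our 307 training volumes and validation_split=0.2 we expect 245 fitting samples, 62 validation samples, and ceil(245 / 32) = 8 batches per epoch:

```python
import math

n_train = 307            # number of volumes in the training set
validation_split = 0.2
batch_size = 32

# Keras uses the first (1 - validation_split) fraction for fitting
n_fit = int(n_train * (1 - validation_split))
n_val = n_train - n_fit
steps_per_epoch = math.ceil(n_fit / batch_size)

print(n_fit, n_val, steps_per_epoch)  # 245 62 8
```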
We will also define a log directory so that we can evaluate our model later on, for example with TensorBoard.
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)
So let's fit the model:
%time fit = model.fit(X_zscore_tr, y_train, epochs=nEpochs, batch_size=batch_size,\
validation_split=0.2, callbacks=[tensorboard_callback])
Epoch 1/100 8/8 [==============================] - 2s 153ms/step - loss: 1.3926 - accuracy: 0.5020 - val_loss: 0.6548 - val_accuracy: 0.6452 Epoch 2/100 8/8 [==============================] - 1s 124ms/step - loss: 0.9597 - accuracy: 0.6041 - val_loss: 0.6324 - val_accuracy: 0.7258 Epoch 3/100 8/8 [==============================] - 1s 125ms/step - loss: 0.7451 - accuracy: 0.6939 - val_loss: 0.6340 - val_accuracy: 0.6129 Epoch 4/100 8/8 [==============================] - 1s 130ms/step - loss: 0.6463 - accuracy: 0.7102 - val_loss: 0.6153 - val_accuracy: 0.7419 Epoch 5/100 8/8 [==============================] - 1s 119ms/step - loss: 0.6486 - accuracy: 0.7224 - val_loss: 0.5925 - val_accuracy: 0.7742 Epoch 6/100 8/8 [==============================] - 1s 128ms/step - loss: 0.6002 - accuracy: 0.7265 - val_loss: 0.5590 - val_accuracy: 0.8065 Epoch 7/100 8/8 [==============================] - 1s 125ms/step - loss: 0.4804 - accuracy: 0.8041 - val_loss: 0.5460 - val_accuracy: 0.7742 Epoch 8/100 8/8 [==============================] - 1s 150ms/step - loss: 0.4568 - accuracy: 0.8245 - val_loss: 0.5388 - val_accuracy: 0.8065 Epoch 9/100 8/8 [==============================] - 1s 134ms/step - loss: 0.3595 - accuracy: 0.8571 - val_loss: 0.5440 - val_accuracy: 0.8387 Epoch 10/100 8/8 [==============================] - 1s 151ms/step - loss: 0.3120 - accuracy: 0.8816 - val_loss: 0.5462 - val_accuracy: 0.7742 Epoch 11/100 8/8 [==============================] - 1s 145ms/step - loss: 0.3281 - accuracy: 0.8612 - val_loss: 0.5513 - val_accuracy: 0.7097 Epoch 12/100 8/8 [==============================] - 1s 137ms/step - loss: 0.2679 - accuracy: 0.8980 - val_loss: 0.5719 - val_accuracy: 0.6129 Epoch 13/100 8/8 [==============================] - 1s 132ms/step - loss: 0.2698 - accuracy: 0.8939 - val_loss: 0.6865 - val_accuracy: 0.5161 Epoch 14/100 8/8 [==============================] - 1s 129ms/step - loss: 0.1656 - accuracy: 0.9265 - val_loss: 0.8108 - val_accuracy: 0.5161 Epoch 15/100 8/8 
[==============================] - 1s 140ms/step - loss: 0.1833 - accuracy: 0.9347 - val_loss: 0.9223 - val_accuracy: 0.5000 Epoch 16/100 8/8 [==============================] - 1s 143ms/step - loss: 0.1648 - accuracy: 0.9429 - val_loss: 0.9505 - val_accuracy: 0.5000 Epoch 17/100 8/8 [==============================] - 1s 146ms/step - loss: 0.1720 - accuracy: 0.9429 - val_loss: 0.8843 - val_accuracy: 0.4839 Epoch 18/100 8/8 [==============================] - 1s 135ms/step - loss: 0.0830 - accuracy: 0.9714 - val_loss: 0.8557 - val_accuracy: 0.4839 Epoch 19/100 8/8 [==============================] - 1s 142ms/step - loss: 0.1872 - accuracy: 0.9306 - val_loss: 0.8204 - val_accuracy: 0.5161 Epoch 20/100 8/8 [==============================] - 1s 136ms/step - loss: 0.1205 - accuracy: 0.9633 - val_loss: 0.8903 - val_accuracy: 0.5000 Epoch 21/100 8/8 [==============================] - 1s 139ms/step - loss: 0.1398 - accuracy: 0.9429 - val_loss: 0.9782 - val_accuracy: 0.5000 Epoch 22/100 8/8 [==============================] - 1s 130ms/step - loss: 0.1030 - accuracy: 0.9592 - val_loss: 1.0515 - val_accuracy: 0.5000 Epoch 23/100 8/8 [==============================] - 1s 148ms/step - loss: 0.0858 - accuracy: 0.9633 - val_loss: 1.0652 - val_accuracy: 0.5161 Epoch 24/100 8/8 [==============================] - 1s 201ms/step - loss: 0.1055 - accuracy: 0.9510 - val_loss: 0.9735 - val_accuracy: 0.5323 Epoch 25/100 8/8 [==============================] - 1s 160ms/step - loss: 0.0699 - accuracy: 0.9714 - val_loss: 0.8374 - val_accuracy: 0.5806 Epoch 26/100 8/8 [==============================] - 1s 151ms/step - loss: 0.0457 - accuracy: 0.9878 - val_loss: 0.7892 - val_accuracy: 0.5968 Epoch 27/100 8/8 [==============================] - 1s 148ms/step - loss: 0.0761 - accuracy: 0.9755 - val_loss: 0.7068 - val_accuracy: 0.6129 Epoch 28/100 8/8 [==============================] - 1s 143ms/step - loss: 0.0493 - accuracy: 0.9837 - val_loss: 0.6480 - val_accuracy: 0.6452 Epoch 29/100 8/8 
[==============================] - 1s 138ms/step - loss: 0.1099 - accuracy: 0.9592 - val_loss: 0.6517 - val_accuracy: 0.6129 Epoch 30/100 8/8 [==============================] - 1s 163ms/step - loss: 0.0493 - accuracy: 0.9878 - val_loss: 0.6734 - val_accuracy: 0.6129 Epoch 31/100 8/8 [==============================] - 1s 168ms/step - loss: 0.0295 - accuracy: 0.9959 - val_loss: 0.7009 - val_accuracy: 0.6452 Epoch 32/100 8/8 [==============================] - 1s 171ms/step - loss: 0.0872 - accuracy: 0.9714 - val_loss: 0.6840 - val_accuracy: 0.6452 Epoch 33/100 8/8 [==============================] - 1s 150ms/step - loss: 0.0352 - accuracy: 0.9959 - val_loss: 0.6040 - val_accuracy: 0.6774 Epoch 34/100 8/8 [==============================] - 1s 145ms/step - loss: 0.0497 - accuracy: 0.9796 - val_loss: 0.5444 - val_accuracy: 0.6935 Epoch 35/100 8/8 [==============================] - 1s 152ms/step - loss: 0.0891 - accuracy: 0.9673 - val_loss: 0.4700 - val_accuracy: 0.7742 Epoch 36/100 8/8 [==============================] - 1s 137ms/step - loss: 0.0537 - accuracy: 0.9837 - val_loss: 0.4536 - val_accuracy: 0.7903 Epoch 37/100 8/8 [==============================] - 1s 147ms/step - loss: 0.0260 - accuracy: 0.9959 - val_loss: 0.4615 - val_accuracy: 0.7903 Epoch 38/100 8/8 [==============================] - 1s 144ms/step - loss: 0.0186 - accuracy: 0.9959 - val_loss: 0.4437 - val_accuracy: 0.8226 Epoch 39/100 8/8 [==============================] - 1s 148ms/step - loss: 0.0380 - accuracy: 0.9837 - val_loss: 0.4077 - val_accuracy: 0.8065 Epoch 40/100 8/8 [==============================] - 1s 143ms/step - loss: 0.0128 - accuracy: 1.0000 - val_loss: 0.3810 - val_accuracy: 0.8387 Epoch 41/100 8/8 [==============================] - 1s 135ms/step - loss: 0.0376 - accuracy: 0.9837 - val_loss: 0.3721 - val_accuracy: 0.8548 Epoch 42/100 8/8 [==============================] - 1s 147ms/step - loss: 0.0119 - accuracy: 1.0000 - val_loss: 0.3940 - val_accuracy: 0.8548 Epoch 43/100 8/8 
[==============================] - 1s 134ms/step - loss: 0.0182 - accuracy: 1.0000 - val_loss: 0.4053 - val_accuracy: 0.8548 Epoch 44/100 8/8 [==============================] - 1s 136ms/step - loss: 0.0365 - accuracy: 0.9878 - val_loss: 0.3956 - val_accuracy: 0.8387 Epoch 45/100 8/8 [==============================] - 1s 143ms/step - loss: 0.0146 - accuracy: 0.9959 - val_loss: 0.3969 - val_accuracy: 0.8387 Epoch 46/100 8/8 [==============================] - 1s 156ms/step - loss: 0.0413 - accuracy: 0.9837 - val_loss: 0.3770 - val_accuracy: 0.8387 Epoch 47/100 8/8 [==============================] - 1s 156ms/step - loss: 0.0255 - accuracy: 0.9918 - val_loss: 0.3671 - val_accuracy: 0.8548 Epoch 48/100 8/8 [==============================] - 1s 183ms/step - loss: 0.0245 - accuracy: 0.9918 - val_loss: 0.3529 - val_accuracy: 0.8387 Epoch 49/100 8/8 [==============================] - 1s 151ms/step - loss: 0.0537 - accuracy: 0.9796 - val_loss: 0.3366 - val_accuracy: 0.8710 Epoch 50/100 8/8 [==============================] - 1s 140ms/step - loss: 0.0158 - accuracy: 1.0000 - val_loss: 0.3621 - val_accuracy: 0.8387 Epoch 51/100 8/8 [==============================] - 1s 142ms/step - loss: 0.0309 - accuracy: 0.9878 - val_loss: 0.3704 - val_accuracy: 0.8387 Epoch 52/100 8/8 [==============================] - 1s 156ms/step - loss: 0.0151 - accuracy: 0.9918 - val_loss: 0.3585 - val_accuracy: 0.8548 Epoch 53/100 8/8 [==============================] - 1s 154ms/step - loss: 0.0172 - accuracy: 0.9918 - val_loss: 0.3300 - val_accuracy: 0.8871 Epoch 54/100 8/8 [==============================] - 1s 163ms/step - loss: 0.0300 - accuracy: 0.9878 - val_loss: 0.3353 - val_accuracy: 0.8871 Epoch 55/100 8/8 [==============================] - 1s 145ms/step - loss: 0.0144 - accuracy: 0.9918 - val_loss: 0.3379 - val_accuracy: 0.8710 Epoch 56/100 8/8 [==============================] - 1s 161ms/step - loss: 0.0313 - accuracy: 0.9878 - val_loss: 0.3412 - val_accuracy: 0.8710 Epoch 57/100 8/8 
[==============================] - 1s 141ms/step - loss: 0.0272 - accuracy: 0.9918 - val_loss: 0.3223 - val_accuracy: 0.8710 Epoch 58/100 8/8 [==============================] - 1s 141ms/step - loss: 0.0123 - accuracy: 0.9959 - val_loss: 0.3440 - val_accuracy: 0.8548 Epoch 59/100 8/8 [==============================] - 1s 163ms/step - loss: 0.0206 - accuracy: 0.9959 - val_loss: 0.3426 - val_accuracy: 0.8387 Epoch 60/100 8/8 [==============================] - 1s 145ms/step - loss: 0.0411 - accuracy: 0.9796 - val_loss: 0.3850 - val_accuracy: 0.8226 Epoch 61/100 8/8 [==============================] - 1s 143ms/step - loss: 0.0720 - accuracy: 0.9673 - val_loss: 0.5091 - val_accuracy: 0.7742 Epoch 62/100 8/8 [==============================] - 1s 133ms/step - loss: 0.0228 - accuracy: 0.9878 - val_loss: 0.5382 - val_accuracy: 0.8065 Epoch 63/100 8/8 [==============================] - 1s 135ms/step - loss: 0.0166 - accuracy: 0.9959 - val_loss: 0.5027 - val_accuracy: 0.8387 Epoch 64/100 8/8 [==============================] - 1s 140ms/step - loss: 0.0239 - accuracy: 0.9878 - val_loss: 0.4835 - val_accuracy: 0.8548 Epoch 65/100 8/8 [==============================] - 1s 141ms/step - loss: 0.0501 - accuracy: 0.9878 - val_loss: 0.4564 - val_accuracy: 0.8387 Epoch 66/100 8/8 [==============================] - 1s 162ms/step - loss: 0.0212 - accuracy: 0.9878 - val_loss: 0.4216 - val_accuracy: 0.8387 Epoch 67/100 8/8 [==============================] - 1s 169ms/step - loss: 0.0235 - accuracy: 0.9918 - val_loss: 0.3966 - val_accuracy: 0.8548 Epoch 68/100 8/8 [==============================] - 1s 155ms/step - loss: 0.0395 - accuracy: 0.9918 - val_loss: 0.4118 - val_accuracy: 0.8226 Epoch 69/100 8/8 [==============================] - 1s 173ms/step - loss: 0.0442 - accuracy: 0.9837 - val_loss: 0.4085 - val_accuracy: 0.8548 Epoch 70/100 8/8 [==============================] - 1s 164ms/step - loss: 0.0220 - accuracy: 0.9918 - val_loss: 0.4225 - val_accuracy: 0.8548 Epoch 71/100 8/8 
[==============================] - 1s 162ms/step - loss: 0.0195 - accuracy: 0.9959 - val_loss: 0.3731 - val_accuracy: 0.8710 Epoch 72/100 8/8 [==============================] - 1s 172ms/step - loss: 0.0095 - accuracy: 0.9959 - val_loss: 0.3665 - val_accuracy: 0.8387 Epoch 73/100 8/8 [==============================] - 1s 144ms/step - loss: 0.0125 - accuracy: 1.0000 - val_loss: 0.3816 - val_accuracy: 0.8548 Epoch 74/100 8/8 [==============================] - 1s 160ms/step - loss: 0.0411 - accuracy: 0.9837 - val_loss: 0.4743 - val_accuracy: 0.8387 Epoch 75/100 8/8 [==============================] - 1s 155ms/step - loss: 0.0286 - accuracy: 0.9918 - val_loss: 0.5085 - val_accuracy: 0.8387 Epoch 76/100 8/8 [==============================] - 1s 147ms/step - loss: 0.0501 - accuracy: 0.9837 - val_loss: 0.5491 - val_accuracy: 0.8548 Epoch 77/100 8/8 [==============================] - 1s 153ms/step - loss: 0.0445 - accuracy: 0.9878 - val_loss: 0.5381 - val_accuracy: 0.8548 Epoch 78/100 8/8 [==============================] - 1s 180ms/step - loss: 0.0085 - accuracy: 0.9959 - val_loss: 0.5165 - val_accuracy: 0.8548 Epoch 79/100 8/8 [==============================] - 1s 152ms/step - loss: 0.0455 - accuracy: 0.9796 - val_loss: 0.4389 - val_accuracy: 0.8871 Epoch 80/100 8/8 [==============================] - 1s 177ms/step - loss: 0.0348 - accuracy: 0.9837 - val_loss: 0.4085 - val_accuracy: 0.8871 Epoch 81/100 8/8 [==============================] - 1s 154ms/step - loss: 0.0270 - accuracy: 0.9918 - val_loss: 0.3844 - val_accuracy: 0.8871 Epoch 82/100 8/8 [==============================] - 1s 173ms/step - loss: 0.0543 - accuracy: 0.9837 - val_loss: 0.3827 - val_accuracy: 0.9194 Epoch 83/100 8/8 [==============================] - 1s 187ms/step - loss: 0.0617 - accuracy: 0.9796 - val_loss: 0.4587 - val_accuracy: 0.8871 Epoch 84/100 8/8 [==============================] - 1s 176ms/step - loss: 0.0086 - accuracy: 1.0000 - val_loss: 0.5762 - val_accuracy: 0.8710 Epoch 85/100 8/8 
[==============================] - 2s 210ms/step - loss: 0.0156 - accuracy: 0.9959 - val_loss: 0.5980 - val_accuracy: 0.8710 Epoch 86/100 8/8 [==============================] - 2s 196ms/step - loss: 0.0386 - accuracy: 0.9878 - val_loss: 0.4881 - val_accuracy: 0.8871 Epoch 87/100 8/8 [==============================] - 1s 184ms/step - loss: 0.0195 - accuracy: 0.9959 - val_loss: 0.5051 - val_accuracy: 0.9032 Epoch 88/100 8/8 [==============================] - 1s 188ms/step - loss: 0.0117 - accuracy: 0.9959 - val_loss: 0.5978 - val_accuracy: 0.8387 Epoch 89/100 8/8 [==============================] - 2s 201ms/step - loss: 0.0134 - accuracy: 0.9959 - val_loss: 0.7096 - val_accuracy: 0.7903 Epoch 90/100 8/8 [==============================] - 2s 207ms/step - loss: 0.0282 - accuracy: 0.9878 - val_loss: 0.8517 - val_accuracy: 0.7258 Epoch 91/100 8/8 [==============================] - 1s 195ms/step - loss: 0.0242 - accuracy: 0.9918 - val_loss: 1.0758 - val_accuracy: 0.6613 Epoch 92/100 8/8 [==============================] - 1s 190ms/step - loss: 0.0252 - accuracy: 0.9959 - val_loss: 0.9523 - val_accuracy: 0.6774 Epoch 93/100 8/8 [==============================] - 1s 201ms/step - loss: 0.0137 - accuracy: 0.9959 - val_loss: 0.7976 - val_accuracy: 0.7097 Epoch 94/100 8/8 [==============================] - 1s 199ms/step - loss: 0.0155 - accuracy: 0.9959 - val_loss: 0.7096 - val_accuracy: 0.7742 Epoch 95/100 8/8 [==============================] - 1s 176ms/step - loss: 0.0236 - accuracy: 0.9918 - val_loss: 0.5864 - val_accuracy: 0.8387 Epoch 96/100 8/8 [==============================] - 1s 177ms/step - loss: 0.0190 - accuracy: 0.9918 - val_loss: 0.5050 - val_accuracy: 0.8548 Epoch 97/100 8/8 [==============================] - 1s 174ms/step - loss: 0.0128 - accuracy: 0.9959 - val_loss: 0.4799 - val_accuracy: 0.8710 Epoch 98/100 8/8 [==============================] - 1s 170ms/step - loss: 0.0228 - accuracy: 0.9959 - val_loss: 0.4938 - val_accuracy: 0.8548 Epoch 99/100 8/8 
[==============================] - 2s 208ms/step - loss: 0.0071 - accuracy: 1.0000 - val_loss: 0.5290 - val_accuracy: 0.8226 Epoch 100/100 8/8 [==============================] - 2s 200ms/step - loss: 0.0247 - accuracy: 0.9878 - val_loss: 0.5363 - val_accuracy: 0.8065 CPU times: user 3min 57s, sys: 2min 12s, total: 6min 9s Wall time: 1min 59s
Let's take a look at the loss and accuracy values during the different epochs, starting with accuracy values:
fig = plt.figure(figsize=(10, 4))
epoch = np.arange(nEpochs) + 1
fontsize = 16
plt.plot(epoch, fit.history['accuracy'], marker="o", linewidth=2,
color="steelblue", label="acc")
plt.plot(epoch, fit.history['val_accuracy'], marker="o", linewidth=2,
color="orange", label="val_acc")
plt.xlabel('epoch', fontsize=fontsize)
plt.xticks(fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.legend(frameon=False, fontsize=16);
Given that we are running this interactively in a Jupyter notebook, we can make use of its capabilities and create an interactive graph using plotly:
def accuracy_epoch_plotly():
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=epoch, y=fit.history['accuracy'],
                             mode='lines+markers', name='acc'))
    fig.add_trace(go.Scatter(x=epoch, y=fit.history['val_accuracy'],
                             mode='lines+markers', name='val_acc'))
    fig.update_layout(
        title={'text': "Accuracy per epoch",
               'y': 0.95, 'x': 0.5,
               'xanchor': 'center', 'yanchor': 'top'},
        xaxis_title="Epoch",
        yaxis_title="Accuracy",
        legend_title="Type",
        template='plotly_white')
    fig.show()
accuracy_epoch_plotly()
Next, we check the loss values, at first via a static plot:
fig = plt.figure(figsize=(10, 4))
epoch = np.arange(nEpochs) + 1
fontsize = 16
plt.plot(epoch, fit.history['loss'], marker="o", linewidth=2,
         color="steelblue", label="loss")
plt.plot(epoch, fit.history['val_loss'], marker="o", linewidth=2,
         color="orange", label="val_loss")
plt.xlabel('epoch', fontsize=fontsize)
plt.xticks(fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.legend(frameon=False, fontsize=16);
and second via an interactive plot:
def loss_epoch_plotly():
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=epoch, y=fit.history['loss'],
                             mode='lines+markers',
                             name='loss'))
    fig.add_trace(go.Scatter(x=epoch, y=fit.history['val_loss'],
                             mode='lines+markers',
                             name='val_loss'))
    fig.update_layout(
        title={
            'text': "Loss per epoch",
            'y': 0.95,
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top'},
        xaxis_title="Epoch",
        yaxis_title="Loss",
        legend_title="Type",
        template='plotly_white'
    )
    fig.show()
loss_epoch_plotly()
Great, the training accuracy keeps climbing and the training loss keeps dropping, even though the validation metrics fluctuate quite a bit. But how well is our model doing on the test data?
evaluation = model.evaluate(X_zscore_te, y_test)
print('Loss in Test set: %.02f' % (evaluation[0]))
print('Accuracy in Test set: %.02f' % (evaluation[1] * 100))
3/3 [==============================] - 0s 14ms/step - loss: 0.6017 - accuracy: 0.8182 Loss in Test set: 0.60 Accuracy in Test set: 81.82
We can also evaluate the model in more detail by computing a confusion matrix, which will provide us with more information concerning the sensitivity and specificity of our model. First, we get the predicted and true labels:
y_pred = np.argmax(model.predict(X_zscore_te), axis=1)
y_pred
array([1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0,
0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1,
0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1])
y_true = y_test * 1
y_true
array([1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0,
1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1])
We can compute the confusion matrix and plot it:
from sklearn.metrics import confusion_matrix
import pandas as pd
class_labels = ['closed', 'open']
cm = pd.DataFrame(confusion_matrix(y_true, y_pred), index=class_labels, columns=class_labels)
sns.heatmap(cm, square=True, annot=True);
Again, an interactive plot might be nice as well:
def confusion_matrix_plotly():
    z_text = [[str(y) for y in x] for x in cm.to_numpy()]
    fig = ff.create_annotated_heatmap(cm.to_numpy(), x=class_labels, y=class_labels,
                                      annotation_text=z_text, colorscale='Magma')
    # add custom xaxis title
    fig.add_annotation(dict(font=dict(color="black", size=14),
                            x=0.5,
                            y=-0.15,
                            showarrow=False,
                            text="Predicted value",
                            xref="paper",
                            yref="paper"))
    # add custom yaxis title
    fig.add_annotation(dict(font=dict(color="black", size=14),
                            x=-0.1,
                            y=0.45,
                            showarrow=False,
                            text="Real value",
                            textangle=-90,
                            xref="paper",
                            yref="paper"))
    fig.show()
confusion_matrix_plotly()
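The sensitivity and specificity mentioned above can be read directly off such a 2x2 confusion matrix. A minimal sketch, using made-up counts (not the exact counts from our heatmap) in the same ['closed', 'open'] order:

```python
import numpy as np

# Hypothetical confusion matrix: rows = true class, columns = predicted class,
# ordered ['closed', 'open'] as above. Counts are illustrative only.
cm_example = np.array([[40, 6],
                       [5, 26]])

tn, fp = cm_example[0]  # true 'closed': correctly vs. wrongly classified
fn, tp = cm_example[1]  # true 'open': wrongly vs. correctly classified

sensitivity = tp / (tp + fn)  # recall for the 'open' class
specificity = tn / (tn + fp)  # recall for the 'closed' class

print(f"sensitivity: {sensitivity:.2f}")
print(f"specificity: {specificity:.2f}")
```

With the real matrix, you would simply replace `cm_example` by `cm.to_numpy()` from above.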
What are the predicted values of the test set?
y_pred = model.predict(X_zscore_te)
y_pred[:10,:]
array([[6.3193397e-04, 9.9936813e-01],
[7.3998719e-01, 2.6001284e-01],
[2.6101706e-04, 9.9973899e-01],
[9.9992204e-01, 7.7959899e-05],
[4.1116646e-04, 9.9958879e-01],
[9.9938703e-01, 6.1293563e-04],
[9.6927714e-01, 3.0722868e-02],
[2.5064608e-03, 9.9749351e-01],
[9.6933562e-01, 3.0664392e-02],
[2.3600699e-01, 7.6399302e-01]], dtype=float32)
As you can see, each sample gets one value per class, and those values lie between 0 and 1.
fig = plt.figure(figsize=(6, 4))
fontsize = 16
plt.hist(y_pred[:,0], bins=16, label='eyes closed')
plt.hist(y_pred[:,1], bins=16, label='eyes open');
plt.xticks(fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.legend(frameon=False, fontsize=16);
As usual, we also generate an interactive plot:
fig = go.Figure()
fig.add_trace(go.Histogram(x=y_pred[:,0],name='eyes closed', nbinsx=16, marker_color='blue'))
fig.add_trace(go.Histogram(x=y_pred[:,1],name='eyes open', nbinsx=16, marker_color='orange'))
fig.update_layout(barmode='stack', template='plotly_white')
fig.show()
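The reason the two columns of each row are complementary is that the final dense layer (presumably with a softmax activation, as is typical for this kind of classifier) normalizes the raw outputs so that each row sums to one. A small numpy sketch with made-up logits:

```python
import numpy as np

def softmax(z):
    # numerically stable softmax over the last axis
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Made-up logits for three samples and two classes ('closed', 'open')
logits = np.array([[-3.2, 4.1],
                   [0.8, -0.2],
                   [2.5, 2.4]])

probs = softmax(logits)
print(probs)
print(probs.sum(axis=1))  # each row sums to 1
```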
The more both distributions cluster around chance level (0.5), the weaker your model is.
Note: Keep in mind that we trained the whole model on only a single train/test split. Ideally, you would repeat this process over many different splits so that your results become less dependent on the particular split you chose.
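The repeated-split idea can be sketched with scikit-learn's StratifiedKFold. This is only an illustration of the splitting logic on synthetic data with a simple stand-in classifier, not our CNN (refitting the network per fold would follow the same loop structure):

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.linear_model import LogisticRegression

# Synthetic stand-in data: 80 "samples", 20 features, binary labels
rng = np.random.default_rng(0)
X = rng.normal(size=(80, 20))
y = rng.integers(0, 2, size=80)

accs = []
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
for train_idx, test_idx in skf.split(X, y):
    # fit a fresh model on each training fold, score on the held-out fold
    clf = LogisticRegression(max_iter=1000)
    clf.fit(X[train_idx], y[train_idx])
    accs.append(clf.score(X[test_idx], y[test_idx]))

print("accuracy per split:", np.round(accs, 2))
print(f"mean: {np.mean(accs):.2f}, sd: {np.std(accs):.2f}")
```

The spread of the per-split accuracies is exactly the split-dependence the note above warns about.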
Finally, as a cool additional feature: We can now visualize the individual filters of the hidden layers. So let's get to it:
# Aggregate the layers
layer_dict = dict([(layer.name, layer) for layer in model.layers])
from tensorflow.keras import backend as K
# Specify a function that visualizes the layers
def show_activation(layer_name):
    layer_output = layer_dict[layer_name].output
    fn = K.function([model.input], [layer_output])
    inp = X_train[0:1]
    this_hidden = fn([inp])[0]

    # plot the activations, 8 filters per row
    plt.figure(figsize=(16, 8))
    nFilters = this_hidden.shape[-1]
    nColumn = 8 if nFilters >= 8 else nFilters
    for i in range(nFilters):
        plt.subplot(int(nFilters / nColumn), nColumn, i + 1)
        plt.imshow(this_hidden[0, :, :, i], cmap='magma', interpolation='nearest')
        plt.axis('off')
    return
Now we can plot the filters of the hidden layers:
layer_dict
{'conv2d': <keras.layers.convolutional.Conv2D at 0x182355360>,
'batch_normalization': <keras.layers.normalization.batch_normalization.BatchNormalization at 0x182364d90>,
'max_pooling2d': <keras.layers.pooling.MaxPooling2D at 0x182356260>,
'conv2d_1': <keras.layers.convolutional.Conv2D at 0x1823576a0>,
'batch_normalization_1': <keras.layers.normalization.batch_normalization.BatchNormalization at 0x182355db0>,
'max_pooling2d_1': <keras.layers.pooling.MaxPooling2D at 0x182355750>,
'conv2d_2': <keras.layers.convolutional.Conv2D at 0x1824b9de0>,
'batch_normalization_2': <keras.layers.normalization.batch_normalization.BatchNormalization at 0x1824b9c60>,
'max_pooling2d_2': <keras.layers.pooling.MaxPooling2D at 0x1824bb160>,
'flatten': <keras.layers.core.flatten.Flatten at 0x1824bb0a0>,
'dropout': <keras.layers.core.dropout.Dropout at 0x1824bb910>,
'dense': <keras.layers.core.dense.Dense at 0x1824d9150>,
'batch_normalization_3': <keras.layers.normalization.batch_normalization.BatchNormalization at 0x1824b94b0>,
'dropout_1': <keras.layers.core.dropout.Dropout at 0x1824d9930>,
'dense_1': <keras.layers.core.dense.Dense at 0x1824d9510>,
'batch_normalization_4': <keras.layers.normalization.batch_normalization.BatchNormalization at 0x1824db130>,
'dropout_2': <keras.layers.core.dropout.Dropout at 0x1824f0cd0>,
'dense_2': <keras.layers.core.dense.Dense at 0x1824db340>,
'batch_normalization_5': <keras.layers.normalization.batch_normalization.BatchNormalization at 0x1824f1060>,
'dropout_3': <keras.layers.core.dropout.Dropout at 0x1824f2ef0>,
'dense_3': <keras.layers.core.dense.Dense at 0x1824f30d0>}
show_activation('conv2d_1')
show_activation('conv2d_2')
The classification accuracy on the training set gets incredibly high, while the validation set also reaches a reasonable accuracy level above 80%. Nonetheless, by only investigating a slab of our fMRI dataset, we might have missed out on some important additional information.
An alternative solution might be to use 3D convolutional neural networks. But keep in mind that they will have even more parameters and will probably take much longer to fit to the training data. Having said so, let's get to it.
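To make the "even more parameters" point concrete, here is a quick back-of-the-envelope comparison using the standard parameter count of a convolutional layer, (kernel_volume * input_channels + 1) * filters; the kernel sizes and filter count are illustrative choices, not the ones our model uses:

```python
def conv_params(kernel, in_channels, filters):
    # (kernel_volume * in_channels + 1 bias) per filter
    vol = 1
    for k in kernel:
        vol *= k
    return (vol * in_channels + 1) * filters

# A 3x3 2D kernel vs. a 3x3x3 3D kernel, 1 input channel, 32 filters
p2d = conv_params((3, 3), 1, 32)     # (9 * 1 + 1) * 32 = 320
p3d = conv_params((3, 3, 3), 1, 32)  # (27 * 1 + 1) * 32 = 896
print(p2d, p3d)
```

Going from 2D to 3D roughly triples the weights of this single layer, and the effect compounds across deeper layers and larger feature maps.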
Going back to graphics, outputs and interactive features of Jupyter notebooks, we can go even crazier and include a running TensorBoard instance to enable interactive evaluation of our model (the future is now):
%load_ext tensorboard
%tensorboard --logdir logs
The tensorboard extension is already loaded. To reload it, use: %reload_ext tensorboard